{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Hugging Face DLCs: Serving Gemma with Text Generation Inference (TGI) on Vertex AI\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Fserving%2Fvertex_ai_text_generation_inference_gemma.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/serving/vertex_ai_text_generation_inference_gemma.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "83b98b0ba19c"
      },
      "source": [
        "![Hugging Face x Google Cloud](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/google-cloud/thumbnail.png)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "| Author(s) | [Alvaro Bartolome](https://github.com/alvarobartt), [Philipp Schmid](https://github.com/philschmid), [Simon Pagezy](https://github.com/pagezyhf), and [Jeff Boudier](https://github.com/jeffboudier) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ccd500ae19b5"
      },
      "source": [
        "## Overview"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "> [**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from the same research and technology used to create the Gemini models, developed by Google DeepMind and other teams across Google.\n",
        "\n",
        "> [**Text Generation Inference (TGI)**](https://github.com/huggingface/text-generation-inference) is a toolkit developed by Hugging Face for deploying and serving LLMs, with high performance text generation.\n",
        "\n",
        "> [**Hugging Face DLCs**](https://github.com/huggingface/Google-Cloud-Containers) are pre-built and optimized Deep Learning Containers (DLCs) maintained by Hugging Face and Google Cloud teams to simplify environment configuration for your ML workloads.\n",
        "\n",
        "> [**Google Vertex AI**](https://cloud.google.com/vertex-ai) is a Machine Learning (ML) platform that lets you train and deploy ML models and AI applications, and customize large language models (LLMs) for use in your AI-powered applications.\n",
        "\n",
        "This notebook showcases how to deploy Google Gemma from the Hugging Face Hub on Vertex AI using the Hugging Face Deep Learning Container (DLC) for Text Generation Inference (TGI).\n",
        "\n",
        "By the end of this notebook, you will learn how to:\n",
        "\n",
        "- Register any LLM from the Hugging Face Hub on Vertex AI\n",
        "- Deploy an LLM on Vertex AI\n",
        "- Send online predictions on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Vertex AI SDK and other required packages\n",
        "\n",
        "To run this example, you will only need the [`google-cloud-aiplatform`](https://github.com/googleapis/python-aiplatform) Python SDK and the [`huggingface_hub`](https://github.com/huggingface/huggingface_hub) Python package."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet google-cloud-aiplatform huggingface_hub"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R5Xep4W9lq-Z"
      },
      "source": [
        "### Restart runtime (Colab only)\n",
        "\n",
        "To use the newly installed packages in this Jupyter environment, you must restart the runtime if you are on Colab. To do so, uncomment and run the cell below, which restarts the current kernel. The restart might take a minute or longer. After the kernel has restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XRvKdaPDTznN"
      },
      "outputs": [],
      "source": [
        "# Automatically restart kernel after installs so that your environment can access the new packages\n",
        "# import IPython\n",
        "\n",
        "# app = IPython.Application.instance()\n",
        "# app.kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SbmM4z7FOBpM"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>\n",
        "</div>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your Google Cloud account\n",
        "\n",
        "Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9cc332aee5fd"
      },
      "source": [
        "**1. Vertex AI Workbench**\n",
        "\n",
        "* Do nothing as you are already authenticated."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "54c81e1a63e6"
      },
      "source": [
        "**2. Local JupyterLab instance, uncomment and run:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c93d96bc9c43"
      },
      "outputs": [],
      "source": [
        "# !gcloud auth login"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d3e571ce6c56"
      },
      "source": [
        "**3. Colab, uncomment and run:**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "984a0526fb68"
      },
      "outputs": [],
      "source": [
        "# from google.colab import auth\n",
        "# auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f9af3e57f89a"
      },
      "source": [
        "### Authenticate your Hugging Face account\n",
        "\n",
        "As [`google/gemma-7b-it`](https://huggingface.co/google/gemma-7b-it) is a gated model, you need a Hugging Face Hub account and must accept Google's usage license for Gemma. Once that's done, generate a new user access token with read-only access so that the Hugging Face DLC for TGI can download the weights from the Hub.\n",
        "\n",
        "> Note that the user access token can only be generated via [the Hugging Face Hub UI](https://huggingface.co/settings/tokens/new), where you can either select read-only access to your account, or follow the recommendations and generate a fine-grained token with read-only access to [`google/gemma-7b-it`](https://huggingface.co/google/gemma-7b-it)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4c31c7272804"
      },
      "source": [
        "The `huggingface_hub` package installed earlier includes a login helper, which you can use to authenticate with the token you just generated. Once you are logged in, the token can be retrieved safely via `huggingface_hub.get_token`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8d836e0210fe"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import interpreter_login\n",
        "\n",
        "interpreter_login()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c71a4314c250"
      },
      "source": [
        "Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI SDK\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com), if not enabled already.\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "from google.cloud import aiplatform\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type:\"string\", isTemplate: true}\n",
        "if PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "aiplatform.init(project=PROJECT_ID, location=LOCATION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ee37e1544281"
      },
      "source": [
        "## Requirements"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "877cd3fb2dce"
      },
      "source": [
        "You will need to have the following IAM roles set:\n",
        "\n",
        "- Artifact Registry Reader (roles/artifactregistry.reader)\n",
        "- Vertex AI User (roles/aiplatform.user)\n",
        "\n",
        "For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).\n",
        "\n",
        "---\n",
        "\n",
        "You will also need to enable the following APIs (if not enabled already):\n",
        "\n",
        "- Vertex AI API (aiplatform.googleapis.com)\n",
        "- Artifact Registry API (artifactregistry.googleapis.com)\n",
        "\n",
        "For more information about API enablement, see [Enabling APIs](https://cloud.google.com/apis/docs/getting-started#enabling_apis).\n",
        "\n",
        "---\n",
        "\n",
        "To access Gemma on Hugging Face, you’re required to review and agree to Google’s usage license on the Hugging Face Hub for any of the models from the [Gemma release collection](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b); the access request will be processed immediately."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EdvJRUWRNGHE"
      },
      "source": [
        "## Register Google Gemma on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "28d24c673821"
      },
      "source": [
        "To serve Gemma with Text Generation Inference (TGI) on Vertex AI, you start importing the model on Vertex AI Model Registry.\n",
        "\n",
        "The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models. From the Model Registry, you have an overview of your models so you can better organize, track, and train new versions. When you have a model version you would like to deploy, you can assign it to an endpoint directly from the registry, or using aliases, deploy models to an endpoint. The models on the Vertex AI Model Registry just contain the model configuration and not the weights per se, as it's not a storage.\n",
        "\n",
        "Before going into the code to upload or import a model on Vertex AI, let's quickly review the arguments provided to the `aiplatform.Model.upload` method:\n",
        "\n",
        "* **`display_name`** is the name that will be shown in the Vertex AI Model Registry.\n",
        "\n",
        "* **`serving_container_image_uri`** is the location of the Hugging Face DLC for TGI that will be used for serving the model.\n",
        "\n",
        "* **`serving_container_environment_variables`** are the environment variables that will be used during the container runtime, so these are aligned with the environment variables defined by `text-generation-inference`, which are analog to the [`text-generation-launcher` arguments](https://huggingface.co/docs/text-generation-inference/en/basic_tutorials/launcher). Additionally, the Hugging Face DLCs for TGI also capture the `AIP_` environment variables from Vertex AI as in [Vertex AI Documentation - Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements).\n",
        "\n",
        "    * `MODEL_ID` is the identifier of the model in the Hugging Face Hub. To explore all the supported models you can check [the Hugging Face Hub](https://huggingface.co/models?sort=trending&other=text-generation-inference).\n",
        "    * `NUM_SHARD` is the number of shards to use if you don't want to use all GPUs on a given machine e.g. if you have two GPUs but you just want to use one for TGI then `NUM_SHARD=1`, otherwise it matches the `CUDA_VISIBLE_DEVICES`.\n",
        "    * `MAX_INPUT_TOKENS` is the maximum allowed input length (expressed in number of tokens), the larger it is, the larger the prompt can be, but also more memory will be consumed.\n",
        "    * `MAX_TOTAL_TOKENS` is the most important value to set as it defines the \"memory budget\" of running clients requests, the larger this value, the larger amount each request will be in your RAM and the less effective batching can be.\n",
        "    * `MAX_BATCH_PREFILL_TOKENS` limits the number of tokens for the prefill operation, as it takes the most memory and is compute bound, it is interesting to limit the number of requests that can be sent.\n",
        "    * `HUGGING_FACE_HUB_TOKEN` is the Hugging Face Hub token, required as [`google/gemma-7b-it`](https://huggingface.co/google/gemma-7b-it) is a gated model.\n",
        "\n",
        "* (optional) **`serving_container_ports`** is the port where the Vertex AI endpoint |will be exposed, by default 8080.\n",
        "\n",
        "For more information on the supported `aiplatform.Model.upload` arguments, check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_upload)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ef487ce082f3"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import get_token\n",
        "\n",
        "model = aiplatform.Model.upload(\n",
        "    display_name=\"google--gemma-7b-it\",\n",
        "    serving_container_image_uri=\"us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-text-generation-inference-cu121.2-2.ubuntu2204.py310\",\n",
        "    serving_container_environment_variables={\n",
        "        \"MODEL_ID\": \"google/gemma-7b-it\",\n",
        "        \"NUM_SHARD\": \"1\",\n",
        "        \"MAX_INPUT_TOKENS\": \"512\",\n",
        "        \"MAX_TOTAL_TOKENS\": \"1024\",\n",
        "        \"MAX_BATCH_PREFILL_TOKENS\": \"1512\",\n",
        "        \"HUGGING_FACE_HUB_TOKEN\": get_token(),\n",
        "    },\n",
        "    serving_container_ports=[8080],\n",
        ")\n",
        "model.wait()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c427c0a87016"
      },
      "source": [
        "## Deploy Google Gemma on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ef4cd160dc09"
      },
      "source": [
        "After the model is registered on Vertex AI, you can deploy the model to an endpoint.\n",
        "\n",
        "You need to first deploy a model to an endpoint before that model can be used to serve online predictions. Deploying a model associates physical resources with the model so it can serve online predictions with low latency.\n",
        "\n",
        "Before going into the code to deploy a model to an endpoint, let's quickly review the arguments provided to the `aiplatform.Model.deploy` method:\n",
        "\n",
        "- **`endpoint`** is the endpoint to deploy the model to, which is optional, and by default will be set to the model display name with the `_endpoint` suffix.\n",
        "- **`machine_type`**, **`accelerator_type`** and **`accelerator_count`** are arguments that define which instance to use, and additionally, the accelerator to use and the number of accelerators, respectively. The `machine_type` and the `accelerator_type` are tied together, so you will need to select an instance that supports the accelerator that you are using and vice-versa. More information about the different instances at [Compute Engine Documentation - GPU machine types](https://cloud.google.com/compute/docs/gpus), and about the `accelerator_type` naming at [Vertex AI Documentation - MachineSpec](https://cloud.google.com/vertex-ai/docs/reference/rest/v1/MachineSpec).\n",
        "\n",
        "For more information on the supported `aiplatform.Model.deploy` arguments, you can check [its Python reference](https://cloud.google.com/python/docs/reference/aiplatform/latest/google.cloud.aiplatform.Model#google_cloud_aiplatform_Model_deploy)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e777427015db"
      },
      "outputs": [],
      "source": [
        "deployed_model = model.deploy(\n",
        "    endpoint=aiplatform.Endpoint.create(display_name=\"google--gemma-7b-it-endpoint\"),\n",
        "    machine_type=\"g2-standard-4\",\n",
        "    accelerator_type=\"NVIDIA_L4\",\n",
        "    accelerator_count=1,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5f2380d3ea60"
      },
      "source": [
        "> Note that the model deployment on Vertex AI can take around 15 to 25 minutes; most of the time being the allocation / reservation of the resources, setting up the network and security, and such."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9b3fd1898241"
      },
      "source": [
        "## Online predictions on Vertex AI"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2aa9cee03bd0"
      },
      "source": [
        "Once the model is deployed on Vertex AI, you can run the online predictions using the `aiplatform.Endpoint.predict` method, which will send the requests to the running endpoint in the `/predict` route specified within the container following Vertex AI I/O payload formatting.\n",
        "\n",
        "As you are serving a `text-generation` model fine-tuned for instruction-following, you will need to make sure that the chat template, if any, is applied correctly to the input conversation; meaning that `transformers` need to be installed so as to instantiate the `tokenizer` for [`google/gemma-7b-it`](https://huggingface.co/google/gemma-7b-it) and run the `apply_chat_template` method over the input conversation before sending the input within the payload to the Vertex AI endpoint.\n",
        "\n",
        "> Note that the Messages API will be supported on Vertex AI on upcoming TGI releases, starting on 2.3, meaning that at the time of writing this post, the prompts need to be formatted before sending the request if you want to achieve nice results (assuming you are using an intruction-following model and not a base model)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "d46c0a866ffc"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --user --quiet transformers jinja2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7e1ab8de0b5c"
      },
      "source": [
        "After the installation is complete, the following snippet will apply the chat template to the conversation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "82fd1842f0cc"
      },
      "outputs": [],
      "source": [
        "from huggingface_hub import get_token\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(\"google/gemma-7b-it\", token=get_token())\n",
        "\n",
        "messages = [\n",
        "    {\"role\": \"user\", \"content\": \"What's Deep Learning?\"},\n",
        "]\n",
        "\n",
        "inputs = tokenizer.apply_chat_template(\n",
        "    messages,\n",
        "    tokenize=False,\n",
        "    add_generation_prompt=True,\n",
        ")\n",
        "# <bos><start_of_turn>user\\nWhat's Deep Learning?<end_of_turn>\\n<start_of_turn>model\\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d0522c9b9bc2"
      },
      "source": [
        "Which is what you will be sending within the payload to the deployed Vertex AI Endpoint, as well as [the generation parameters](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "79e393e97ddb"
      },
      "source": [
        "### Via Python"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1a404fd19a08"
      },
      "source": [
        "#### Within the same session\n",
        "\n",
        "To run the online prediction via the Vertex AI SDK, you can simply use the `predict` method."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4ab63079a69a"
      },
      "outputs": [],
      "source": [
        "output = deployed_model.predict(\n",
        "    instances=[\n",
        "        {\n",
        "            \"inputs\": \"<bos><start_of_turn>user\\nWhat's Deep Learning?<end_of_turn>\\n<start_of_turn>model\\n\",\n",
        "            \"parameters\": {\n",
        "                \"max_new_tokens\": 20,\n",
        "                \"do_sample\": True,\n",
        "                \"top_p\": 0.95,\n",
        "                \"temperature\": 1.0,\n",
        "            },\n",
        "        },\n",
        "    ]\n",
        ")\n",
        "output"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b28efb473c7e"
      },
      "source": [
        "#### From a different session\n",
        "\n",
        "To run the online prediction from a different session, you can run the following snippet."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e003ab03f30f"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "from google.cloud import aiplatform\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type:\"string\", isTemplate: true}\n",
        "if PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "aiplatform.init(project=PROJECT_ID, location=LOCATION)\n",
        "\n",
        "endpoint_display_name = (\n",
        "    \"google--gemma-7b-it-endpoint\"  # TODO: change to your endpoint display name\n",
        ")\n",
        "\n",
        "# Iterates over all the Vertex AI Endpoints within the current project and keeps the first match (if any), otherwise set to None\n",
        "ENDPOINT_ID = next(\n",
        "    (\n",
        "        endpoint.name\n",
        "        for endpoint in aiplatform.Endpoint.list()\n",
        "        if endpoint.display_name == endpoint_display_name\n",
        "    ),\n",
        "    None,\n",
        ")\n",
        "assert ENDPOINT_ID, (\n",
        "    \"`ENDPOINT_ID` is not set, please make sure that the `endpoint_display_name` is correct at \"\n",
        "    f\"https://console.cloud.google.com/vertex-ai/online-prediction/endpoints?project={PROJECT_ID}\"\n",
        ")\n",
        "\n",
        "endpoint = aiplatform.Endpoint(\n",
        "    f\"projects/{PROJECT_ID}/locations/{LOCATION}/endpoints/{ENDPOINT_ID}\"\n",
        ")\n",
        "output = endpoint.predict(\n",
        "    instances=[\n",
        "        {\n",
        "            \"inputs\": \"<bos><start_of_turn>user\\nWhat's Deep Learning?<end_of_turn>\\n<start_of_turn>model\\n\",\n",
        "            \"parameters\": {\n",
        "                \"max_new_tokens\": 20,\n",
        "                \"do_sample\": True,\n",
        "                \"top_p\": 0.95,\n",
        "                \"temperature\": 0.7,\n",
        "            },\n",
        "        },\n",
        "    ],\n",
        ")\n",
        "output"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7cc0b19ab641"
      },
      "source": [
        "### Via gcloud"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8fed7007c31c"
      },
      "source": [
        "You can also send the requests using the `gcloud` CLI via the `gcloud ai endpoints` command.\n",
        "\n",
        "> Note that, before proceeding, you should either replace the values or set the following environment variables in advance from the Python variables set in the example, as follows:\n",
        ">\n",
        "> ```python\n",
        "> import os\n",
        "> os.environ[\"PROJECT_ID\"] = PROJECT_ID\n",
        "> os.environ[\"LOCATION\"] = LOCATION\n",
        "> os.environ[\"ENDPOINT_NAME\"] = \"google--gemma-7b-it-endpoint\"\n",
        "> ```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7fb1547d409f"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "ENDPOINT_ID=$(gcloud ai endpoints list \\\n",
        "  --project=$PROJECT_ID \\\n",
        "  --region=$LOCATION \\\n",
        "  --filter=\"display_name=$ENDPOINT_NAME\" \\\n",
        "  --format=\"value(name)\" \\\n",
        "  | cut -d'/' -f6)\n",
        "\n",
        "echo '{\n",
        "  \"instances\": [\n",
        "    {\n",
        "      \"inputs\": \"<bos><start_of_turn>user\\nWhat'\\''s Deep Learning?<end_of_turn>\\n<start_of_turn>model\\n\",\n",
        "      \"parameters\": {\n",
        "        \"max_new_tokens\": 20,\n",
        "        \"do_sample\": true,\n",
        "        \"top_p\": 0.95,\n",
        "        \"temperature\": 1.0\n",
        "      }\n",
        "    }\n",
        "  ]\n",
        "}' | gcloud ai endpoints predict $ENDPOINT_ID \\\n",
        "  --project=$PROJECT_ID \\\n",
        "  --region=$LOCATION \\\n",
        "  --json-request=\"-\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7e588a65759e"
      },
      "source": [
        "### Via cURL"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cf3de731bcf4"
      },
      "source": [
        "Alternatively, you can also send the requests via `cURL`.\n",
        "\n",
        "> Note that, before proceeding, you should either replace the values or set the following environment variables in advance from the Python variables set in the example, as follows:\n",
        ">\n",
        "> ```python\n",
        "> import os\n",
        "> os.environ[\"PROJECT_ID\"] = PROJECT_ID\n",
        "> os.environ[\"LOCATION\"] = LOCATION\n",
        "> os.environ[\"ENDPOINT_NAME\"] = \"google--gemma-7b-it-endpoint\"\n",
        "> ```"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a71da4e1152b"
      },
      "outputs": [],
      "source": [
        "%%bash\n",
        "ENDPOINT_ID=$(gcloud ai endpoints list \\\n",
        "  --project=$PROJECT_ID \\\n",
        "  --region=$LOCATION \\\n",
        "  --filter=\"display_name=$ENDPOINT_NAME\" \\\n",
        "  --format=\"value(name)\" \\\n",
        "  | cut -d'/' -f6)\n",
        "\n",
        "curl -X POST \\\n",
        "    -H \"Authorization: Bearer $(gcloud auth print-access-token)\" \\\n",
        "    -H \"Content-Type: application/json\" \\\n",
        "    https://$LOCATION-aiplatform.googleapis.com/v1/projects/$PROJECT_ID/locations/$LOCATION/endpoints/$ENDPOINT_ID:predict \\\n",
        "    -d '{\n",
        "        \"instances\": [\n",
        "            {\n",
        "                \"inputs\": \"<bos><start_of_turn>user\\nWhat'\\''s Deep Learning?<end_of_turn>\\n<start_of_turn>model\\n\",\n",
        "                \"parameters\": {\n",
        "                    \"max_new_tokens\": 20,\n",
        "                    \"do_sample\": true,\n",
        "                    \"top_p\": 0.95,\n",
        "                    \"temperature\": 1.0\n",
        "                }\n",
        "            }\n",
        "        ]\n",
        "    }'"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f6f17f9aff65"
      },
      "source": [
        "## Cleaning up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2a4e033321ad"
      },
      "source": [
        "Finally, you can already release the resources that you've created as follows, to avoid unnecessary costs:\n",
        "\n",
        "- `deployed_model.undeploy_all` to undeploy the model from all the endpoints.\n",
        "- `deployed_model.delete` to delete the endpoint/s where the model was deployed gracefully, after the `undeploy_all` method.\n",
        "- `model.delete` to delete the model from the registry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b074ffdbd0c4"
      },
      "outputs": [],
      "source": [
        "deployed_model.undeploy_all()\n",
        "deployed_model.delete()\n",
        "model.delete()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6bc4da2fcc35"
      },
      "source": [
        "Alternatively, you can also remove those from the Google Cloud Console following the steps:\n",
        "\n",
        "- Go to Vertex AI in Google Cloud\n",
        "- Go to Deploy and use -> Online prediction\n",
        "- Click on the endpoint and then on the deployed model/s to \"Undeploy model from endpoint\"\n",
        "- Then go back to the endpoint list and remove the endpoint\n",
        "- Finally, go to Deploy and use -> Model Registry, and remove the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "137d82da56f2"
      },
      "source": [
        "## References\n",
        "\n",
        "- [GitHub Repository - Hugging Face DLCs for Google Cloud](https://github.com/huggingface/Google-Cloud-Containers): contains all the containers developed by the collaboration of both Hugging Face and Google Cloud teams; as well as a lot of examples on both training and inference, covering both CPU and GPU, as well support for most of the models within the Hugging Face Hub.\n",
        "- [Google Cloud Documentation - Hugging Face DLCs](https://cloud.google.com/deep-learning-containers/docs/choosing-container#hugging-face): contains a table with the latest released Hugging Face DLCs on Google Cloud.\n",
        "- [Google Artifact Registry - Hugging Face DLCs](https://console.cloud.google.com/artifacts/docker/deeplearning-platform-release/us/gcr.io): contains all the DLCs released by Google Cloud that can be used.\n",
        "- [Hugging Face Documentation - Google Cloud](https://huggingface.co/docs/google-cloud): contains the official Hugging Face documentation for the Google Cloud DLCs."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "vertex_ai_text_generation_inference_gemma.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
